{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "经过了第四讲的铺垫，在本讲中，我们来正式书写自注意力模块"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn  \n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们能够看出，目前这种方式只是简单实现过去的编码和当前的编码的十分简单的加权和，也即权重都是通过数量的平均获得，在此之前的每一个字符的编码的权重都是一样的，如果**不同的字符可以根据其自身的情况，根据其自身可以发挥的不同作用**从而生成不同权重的话，，， 没错，这就是自注意力机制做的事情"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 自注意力机制可以被简单地描述为这么一回事\n",
    "     - 对于每一个被编码的位置，其会有两个向量，一个我们称之为`key`,而另一个我们称之为`query`\n",
    "     - 我们可以想象，对于每一个`query`，他一直在说**我在寻找什么**\n",
    "     - 而对于每一个`key`，他一直在说**我是什么，我所包含的是什么**\n",
    "\n",
    "- 我们在不同序列之间获取联系的方式就是通过`key`和`query`之间的交互，更具体一点来讲，即通过一个点乘，在`key`和`query`之间实现\n",
    "\n",
    "- 而`key`和`query`之间的点积就变成了权重`weight`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lecture3中我们提到的最后一种: 自注意力的实现方式\n",
    "torch.manual_seed(1337)\n",
    "B,T,C = 4,8,32 # batch, time, channels\n",
    "x = torch.randn(B,T,C)\n",
    "\n",
    "trils = torch.tril(torch.ones(T,T))\n",
    "weight = torch.zeros(T,T)\n",
    "weight = weight.masked_fill(trils == 0,float('-inf'))  # 使所有tril为0的位置都变为无穷大\n",
    "# 然后，我们选择在每行的维度上去使用sotfmax，\n",
    "weight = F.softmax(weight,dim=-1)\n",
    "\n",
    "out = weight @ x\n",
    "\n",
    "out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query & Key\n",
    "torch.manual_seed(1337)\n",
    "B,T,C = 4,8,32 # batch, time, channels\n",
    "x = torch.randn(B,T,C)\n",
    "\n",
    "\n",
    "# 一个简单的自注意力头的实现\n",
    "head_size = 16  # 指定头的大小\n",
    "\n",
    "# 实例化线性层\n",
    "key = nn.Linear(C,head_size,bias = False)\n",
    "query = nn.Linear(C,head_size = False)\n",
    "\n",
    "# \n",
    "k = key(x)   # (B,T,C) ---> (B,T,16)\n",
    "q = query(x)  # (B,T,C) ---> (B,T,16)\n",
    "\n",
    "weight = q @ k.transpose()   # 将query与key进行点乘  (B,T,16) @ (B,16,T) ---> (B,T,T),得到我们想要的权重\n",
    "\n",
    "trils = torch.tril(torch.ones(T,T))\n",
    "weight = weight.masked_fill(trils == 0,float('-inf'))  # 使所有tril为0的位置都变为无穷大\n",
    "# 然后，我们选择在每行的维度上去使用sotfmax，\n",
    "weight = F.softmax(weight,dim=-1)\n",
    "\n",
    "out = weight @ x\n",
    "\n",
    "out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query & Key\n",
    "torch.manual_seed(1337)\n",
    "B,T,C = 4,8,32 # batch, time, channels\n",
    "x = torch.randn(B,T,C)\n",
    "\n",
    "\n",
    "# 一个简单的自注意力头的实现\n",
    "head_size = 16  # 指定头的大小\n",
    "\n",
    "# 实例化线性层\n",
    "key = nn.Linear(C,head_size,bias = False)\n",
    "query = nn.Linear(C,head_size = False)\n",
    "\n",
    "# \n",
    "k = key(x)   # (B,T,C) ---> (B,T,16)\n",
    "q = query(x)  # (B,T,C) ---> (B,T,16)\n",
    "\n",
    "weight = q @ k.transpose()   # 将query与key进行点乘  (B,T,16) @ (B,16,T) ---> (B,T,T),得到我们想要的权重\n",
    "\n",
    "trils = torch.tril(torch.ones(T,T))\n",
    "weight = weight.masked_fill(trils == 0,float('-inf'))  # 使所有tril为0的位置都变为无穷大\n",
    "# 然后，我们选择在每行的维度上去使用sotfmax，\n",
    "weight = F.softmax(weight,dim=-1)\n",
    "\n",
    "out = weight @ x\n",
    "\n",
    "out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 8, 16])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Query & Key & Value\n",
    "torch.manual_seed(1337)\n",
    "B,T,C = 4,8,32 # batch, time, channels\n",
    "x = torch.randn(B,T,C)\n",
    "\n",
    "\n",
    "# 一个简单的自注意力头的实现\n",
    "head_size = 16  # 每个自注意力头的大小\n",
    "\n",
    "# 实例化线性层\n",
    "key = nn.Linear(C,head_size,bias = False)\n",
    "query = nn.Linear(C,head_size,bias = False)\n",
    "value = nn.Linear(C,head_size,bias = False)\n",
    "\n",
    "# \n",
    "k = key(x)   # (B,T,C) ---> (B,T,16)\n",
    "q = query(x)  # (B,T,C) ---> (B,T,16)\n",
    "\n",
    "weight = q @ k.transpose(-2,-1)   # 将query与key进行点乘  (B,T,16) @ (B,16,T) ---> (B,T,T),得到我们想要的权重\n",
    "\n",
    "# 根据原版的公式，我们还要做除以headsize的开方\n",
    "weight = weight * head_size ** 0.5\n",
    "\n",
    "trils = torch.tril(torch.ones(T,T))\n",
    "weight = weight.masked_fill(trils == 0,float('-inf'))  # 使所有tril为0的位置都变为无穷大\n",
    "# 然后，我们选择在每行的维度上去使用sotfmax，\n",
    "weight = F.softmax(weight,dim=-1)\n",
    "\n",
    "\n",
    "# 我们让x也经过一个线性层进行分头 ，对于这里的value 我们可以理解为将x进行剥皮，去发现其本质是什么东西，从而更好的来利用q和k\n",
    "x = value(x)\n",
    "out = weight @ x\n",
    "\n",
    "\n",
    "out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],\n",
       "        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],\n",
       "        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],\n",
       "        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],\n",
       "       grad_fn=<SelectBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 此时，我们可以看到之前每个编码的权重变得不再一样了\n",
    "weight[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. 注意力本身是一种通信机制，可以将其视为在一个有向的图中，每个节点都会有指向其他节点的边，同时边的权重还是不同的。\n",
    "2. 注意力其实并没有空间的概念，可以将数字的先后想象成一个高维度的向量，向量此时如果进行顺序的变换其实是不会影响结果的，"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
